import os
from typing import Annotated

from langchain_community.chat_models import ChatZhipuAI
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_core.messages import AIMessage, ToolMessage
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import START, StateGraph, END
from langgraph.graph.message import add_messages
from langgraph.prebuilt import ToolNode, tools_condition
from pydantic import BaseModel
from typing_extensions import TypedDict

# API credentials are read from the environment instead of being hard-coded.
# NOTE(security): the keys that used to be committed at this spot were exposed
# in plain text and should be revoked/rotated with their providers.
# The ZhipuAI chat model and the Tavily search tool pick their keys up from
# os.environ, so we only verify they are present and fail fast with a clear
# message otherwise.
_missing_keys = [
    _key_name
    for _key_name in ("ZHIPUAI_API_KEY", "TAVILY_API_KEY")
    if not os.environ.get(_key_name)
]
if _missing_keys:
    raise RuntimeError(
        "Missing required environment variable(s): " + ", ".join(_missing_keys)
    )

# Enable LangSmith tracing only when a LangSmith key is actually available,
# and never override an explicit LANGCHAIN_TRACING_V2 setting from the caller.
if os.environ.get("LANGCHAIN_API_KEY"):
    os.environ.setdefault("LANGCHAIN_TRACING_V2", "true")


# Schema for a "pseudo-tool": it is bound to the LLM together with the real
# search tool, but it is not registered with ToolNode, so it is never executed.
# When the model's first tool call targets it, chatbot sets ask_human and the
# graph routes to the "human" node instead of "tools".
# NOTE: the docstring below is runtime text, not just documentation. The class
# is passed to bind_tools, and LangChain presumably uses the docstring as the
# tool description shown to the model, so edit it with care.
class RequestAssistance(BaseModel):
    """Escalate the conversation to an expert. Use this if you are unable to assist directly or if the user requires support beyond your permissions.
    To use this function, relay the user's 'request' so the expert can provide the right guidance.
    """
    # The user's request to pass on to the human expert (see docstring).
    request: str


# Web-search tool (at most two results per query). It is the only tool the
# ToolNode can execute; RequestAssistance is handled by the human node.
tool = TavilySearchResults(max_results=2)
tools = [tool]
tool_node = ToolNode(tools=list(tools))

# In-memory checkpointer: keeps per-thread graph state so a paused run can be
# inspected, updated and resumed within this process.
memory = MemorySaver()


class State(TypedDict):
    """Graph state shared by every node."""

    # Conversation history. The add_messages reducer merges the messages each
    # node returns into this list rather than replacing it.
    messages: Annotated[list, add_messages]
    # Set to True by chatbot when the model requests RequestAssistance, which
    # routes the next step to the "human" node; human_node resets it to False.
    ask_human: bool


graph_builder = StateGraph(State)

# GLM-4 chat model that may either call the search tool or escalate to a
# human by calling the RequestAssistance pseudo-tool.
_chat_model = ChatZhipuAI(temperature=0.5, model="glm-4")
llm = _chat_model.bind_tools([*tools, RequestAssistance])


def chatbot(state: State):
    """Run the tool-enabled LLM on the conversation so far.

    Returns a partial state update with the model's reply appended to
    ``messages``. ``ask_human`` is True only when the reply's first tool call
    is RequestAssistance, which tells the router to escalate to a human.
    """
    reply = llm.invoke(state["messages"])
    calls = reply.tool_calls
    wants_expert = bool(calls) and calls[0]['name'] == RequestAssistance.__name__
    return {"messages": [reply], "ask_human": wants_expert}


def create_response(response: str, ai_message: AIMessage):
    """Wrap ``response`` as a ToolMessage answering ``ai_message``'s first tool call.

    Raises IndexError if ``ai_message`` carries no tool calls.
    """
    first_call_id = ai_message.tool_calls[0]["id"]
    return ToolMessage(content=response, tool_call_id=first_call_id)


def human_node(state: State):
    """Finish an escalation step after the graph resumes from its interrupt.

    If a person already injected a ToolMessage (e.g. via update_state), the
    conversation is left as-is. Otherwise a placeholder "No response from
    human." reply is attached to the pending tool call so the next LLM turn
    sees an answer. In both cases ``ask_human`` is cleared.
    """
    last_message = state["messages"][-1]
    if isinstance(last_message, ToolMessage):
        additions = []
    else:
        additions = [create_response("No response from human.", last_message)]
    return {"messages": additions, "ask_human": False}


def select_next_node(state: State):
    """Choose where to go after a chatbot turn.

    Returns "human" when chatbot flagged an escalation; otherwise defers to
    LangGraph's tools_condition, whose result is mapped to the "tools" node
    or END in the graph wiring.
    """
    wants_human = state["ask_human"]
    return "human" if wants_human else tools_condition(state)


# Register the nodes: chatbot is the hub, tools and human are its helpers.
for _node_name, _node_impl in (
    ("chatbot", chatbot),
    ("tools", tool_node),
    ("human", human_node),
):
    graph_builder.add_node(_node_name, _node_impl)

# Every run starts at chatbot, and both helpers hand control back to it.
graph_builder.add_edge(START, "chatbot")
for _source in ("tools", "human"):
    graph_builder.add_edge(_source, "chatbot")

# After each chatbot turn, select_next_node picks "human", "tools" or END.
_route_map = {"human": "human", "tools": "tools", END: END}
graph_builder.add_conditional_edges("chatbot", select_next_node, _route_map)

# Checkpoint state per thread and pause before the human node, so a person
# can inject a reply before the graph continues.
graph = graph_builder.compile(checkpointer=memory, interrupt_before=["human"])

user_input = "I need some expert guidance for building this AI agent. Could you request assistance for me?"
config = {"configurable": {"thread_id": "1"}}

# First pass: run until the graph finishes or pauses before the "human" node.
events = graph.stream(
    {"messages": [("user", user_input)]},
    config,
    stream_mode="values",
)
for event in events:
    if "messages" in event:
        event["messages"][-1].pretty_print()

# The graph was compiled with interrupt_before=["human"], so if the model
# escalated via RequestAssistance the run is now paused in front of "human".
snapshot = graph.get_state(config)
print(snapshot.next)

if "human" in snapshot.next:
    # Play the expert: answer the pending tool call on the last AI message.
    # Only reached when chatbot set ask_human, i.e. that message has tool calls.
    ai_message = snapshot.values["messages"][-1]
    human_response = (
        "We, the experts are here to help! We'd recommend you check out LangGraph to build your agent."
        " It's much more reliable and extensible than simple autonomous agents."
    )
    tool_message = create_response(human_response, ai_message)
    graph.update_state(
        config,
        {"messages": tool_message},
    )

    # Resume the paused run from its checkpoint (None means "no new input").
    events = graph.stream(None, config, stream_mode="values")
    for event in events:
        if "messages" in event:
            event["messages"][-1].pretty_print()
else:
    # The model answered directly; there is no pending tool call to answer,
    # and indexing tool_calls[0] here would raise IndexError.
    print("The model did not request human assistance; nothing to resume.")